# RUN: tf-mlir-translate -graphdef-to-mlir -tf-enable-shape-inference-on-import=false %s -tf-input-arrays=z:1,z:2 -tf-input-shapes=':' -tf-output-arrays=z:2,z:1,a:0 -o - | FileCheck %s
# RUN: tf-mlir-translate -graphdef-to-mlir -tf-enable-shape-inference-on-import=false %s -tf-prune-unused-nodes -tf-input-arrays=z:1,z:2 -tf-input-shapes=':' -tf-output-arrays=z:2,z:1,a:0 -o - | FileCheck --check-prefix=PRUNE %s
# RUN: tf-mlir-translate -graphdef-to-mlir -tf-enable-shape-inference-on-import=false %s -tf-prune-unused-nodes -tf-input-arrays=z:1,z:2 -tf-input-shapes=':' -tf-output-arrays=z:0,a:0 -o - | FileCheck --check-prefix=PRESERVE %s

# Generated in Python via
# ```
# import tensorflow as tf
#
# with tf.compat.v1.Graph().as_default() as g:
#   w = tf.constant(2.0)
#   x = tf.constant(3.0)
#   y = tf.constant(4.0)
#   var = tf.Variable(2.0)
#   var_add = var.assign_add(1.0)
#   with g.control_dependencies([var_add]):
#     z0, z1, z2 = tf.identity_n((w, x, y))
#
#   a = tf.add(z1, z2)
# ```

# Scalar DT_FLOAT constant 2.0; first data input of the IdentityN node "z".
node {
  name: "w"
  op: "Const"
  attr {
    key: "dtype"
    value {
      type: DT_FLOAT
    }
  }
  attr {
    key: "value"
    value {
      tensor {
        dtype: DT_FLOAT
        tensor_shape {
        }
        float_val: 2.0
      }
    }
  }
}
# Scalar DT_FLOAT constant 3.0; second data input of the IdentityN node "z".
node {
  name: "x"
  op: "Const"
  attr {
    key: "dtype"
    value {
      type: DT_FLOAT
    }
  }
  attr {
    key: "value"
    value {
      tensor {
        dtype: DT_FLOAT
        tensor_shape {
        }
        float_val: 3.0
      }
    }
  }
}
# Scalar DT_FLOAT constant 4.0; third data input of the IdentityN node "z".
node {
  name: "y"
  op: "Const"
  attr {
    key: "dtype"
    value {
      type: DT_FLOAT
    }
  }
  attr {
    key: "value"
    value {
      tensor {
        dtype: DT_FLOAT
        tensor_shape {
        }
        float_val: 4.0
      }
    }
  }
}
# Scalar DT_FLOAT constant 2.0; the value input of the "var/Assign" node.
node {
  name: "var/initial_value"
  op: "Const"
  attr {
    key: "dtype"
    value {
      type: DT_FLOAT
    }
  }
  attr {
    key: "value"
    value {
      tensor {
        dtype: DT_FLOAT
        tensor_shape {
        }
        float_val: 2.0
      }
    }
  }
}
# Scalar (empty shape) DT_FLOAT VariableV2 with empty container and
# shared_name. It is an input of "var/Assign", "var/read" and "var_add".
node {
  name: "var"
  op: "VariableV2"
  attr {
    key: "container"
    value {
      s: ""
    }
  }
  attr {
    key: "dtype"
    value {
      type: DT_FLOAT
    }
  }
  attr {
    key: "shape"
    value {
      shape {
      }
    }
  }
  attr {
    key: "shared_name"
    value {
      s: ""
    }
  }
}
# Assign op taking "var" and "var/initial_value" as inputs. No other node in
# this graph consumes its output.
node {
  name: "var/Assign"
  op: "Assign"
  input: "var"
  input: "var/initial_value"
  attr {
    key: "T"
    value {
      type: DT_FLOAT
    }
  }
  attr {
    key: "_class"
    value {
      list {
        s: "loc:@var"
      }
    }
  }
  attr {
    key: "use_locking"
    value {
      b: true
    }
  }
  attr {
    key: "validate_shape"
    value {
      b: true
    }
  }
}
# Identity op reading "var". No other node in this graph consumes its output.
node {
  name: "var/read"
  op: "Identity"
  input: "var"
  attr {
    key: "T"
    value {
      type: DT_FLOAT
    }
  }
  attr {
    key: "_class"
    value {
      list {
        s: "loc:@var"
      }
    }
  }
}
# Scalar DT_FLOAT constant 1.0; the increment input of the "var_add" node.
node {
  name: "var_add/value"
  op: "Const"
  attr {
    key: "dtype"
    value {
      type: DT_FLOAT
    }
  }
  attr {
    key: "value"
    value {
      tensor {
        dtype: DT_FLOAT
        tensor_shape {
        }
        float_val: 1.0
      }
    }
  }
}
# AssignAdd op taking "var" and "var_add/value" as inputs. The IdentityN node
# "z" lists this node as a control input ("^var_add"), which is why the
# expected output below threads the AssignAdd control token into the
# IdentityN island.
node {
  name: "var_add"
  op: "AssignAdd"
  input: "var"
  input: "var_add/value"
  attr {
    key: "T"
    value {
      type: DT_FLOAT
    }
  }
  attr {
    key: "_class"
    value {
      list {
        s: "loc:@var"
      }
    }
  }
  attr {
    key: "use_locking"
    value {
      b: false
    }
  }
}
# Multi-output IdentityN over "w", "x" and "y" (three DT_FLOAT outputs:
# z:0, z:1, z:2). The test invocations at the top of this file feed z:1 and
# z:2, and fetch either z:2/z:1 or z:0 from this node.
node {
  name: "z"
  op: "IdentityN"
  input: "w"
  input: "x"
  input: "y"
  # The '^' prefix makes this a control input rather than a data input.
  input: "^var_add"
  attr {
    key: "T"
    value {
      list {
        type: DT_FLOAT
        type: DT_FLOAT
        type: DT_FLOAT
      }
    }
  }
}
# Add op over the second and third outputs of "z" (z:1 and z:2). Both of these
# tensors are replaced by function arguments when they are supplied as feeds,
# as the expected output below shows.
node {
  name: "a"
  op: "Add"
  input: "z:1"
  input: "z:2"
  attr {
    key: "T"
    value {
      type: DT_FLOAT
    }
  }
}
# GraphDef version information: the graph records producer version 230.
versions {
  producer: 230
}

# Test non-zero-index output tensors as feeds. The original ops whose outputs
# are replaced by feeds are preserved, and args and rets are lifted to the
# function. A ret that coincides with a feed takes the value of that feed.
#
# CHECK-LABEL: func @main
# CHECK-SAME:  (%[[ARG_0:.*]]: tensor<f32>, %[[ARG_1:.*]]: tensor<f32>) -> (tensor<f32>, tensor<f32>, tensor<*xf32>)
# CHECK-SAME:  control_outputs = ""
# CHECK-SAME:  inputs = "z:1,z:2"
# CHECK-SAME:  outputs = "z:2,z:1,a:0"
# CHECK:           %{{.*}}, %[[ASSIGN_ADD_CTRL:.*]] = tf_executor.island wraps "tf.AssignAdd"
# CHECK:           %{{.*}}, %{{.*}} = tf_executor.island(%[[ASSIGN_ADD_CTRL]]) wraps "tf.IdentityN"
# CHECK:           %[[ADD:.*]], %{{.*}} = tf_executor.island wraps "tf.Add"(%[[ARG_0]], %[[ARG_1]])
# CHECK:           tf_executor.fetch %[[ARG_1]], %[[ARG_0]], %[[ADD]]

# Test when non zero index output tensors are feeds, remaining ops that are
# unreachable are pruned if pruning is enabled.
#
# PRUNE-LABEL: func @main
# PRUNE-SAME:  (%[[ARG_0:.*]]: tensor<f32>, %[[ARG_1:.*]]: tensor<f32>) -> (tensor<f32>, tensor<f32>, tensor<*xf32>)
# PRUNE-SAME:  control_outputs = ""
# PRUNE-SAME:  inputs = "z:1,z:2"
# PRUNE-SAME:  outputs = "z:2,z:1,a:0"
# PRUNE-NOT:       "tf.Const"
# PRUNE-NOT:       "tf.VariableV2"
# PRUNE-NOT:       "tf.Assign"
# PRUNE-NOT:       "tf.Identity"
# PRUNE-NOT:       "tf.AssignAdd"
# PRUNE-NOT:       "tf.IdentityN"
# PRUNE:           %[[ADD:.*]], %{{.*}} = tf_executor.island wraps "tf.Add"(%[[ARG_0]], %[[ARG_1]])
# PRUNE:           tf_executor.fetch %[[ARG_1]], %[[ARG_0]], %[[ADD]]

# Test when non zero index output tensors are feeds, remaining ops that are
# still reachable from the fetches are preserved even if pruning is enabled.
#
# PRESERVE-LABEL: func @main
# PRESERVE-SAME:  (%[[ARG_0:.*]]: tensor<f32>, %[[ARG_1:.*]]: tensor<f32>) -> (tensor<*xf32>, tensor<*xf32>)
# PRESERVE-SAME:  control_outputs = ""
# PRESERVE-SAME:  inputs = "z:1,z:2"
# PRESERVE-SAME:  outputs = "z:0,a:0"
# PRESERVE:           %{{.*}}, %[[ASSIGN_ADD_CTRL:.*]] = tf_executor.island wraps "tf.AssignAdd"
# PRESERVE:           %[[IDENTITY_N:.*]]:3, %{{.*}} = tf_executor.island(%[[ASSIGN_ADD_CTRL]]) wraps "tf.IdentityN"
# PRESERVE:           %[[ADD:.*]], %{{.*}} = tf_executor.island wraps "tf.Add"(%[[ARG_0]], %[[ARG_1]])
# PRESERVE:           tf_executor.fetch %[[IDENTITY_N]]#0, %[[ADD]]
